from torchmetrics import MaxMetric
from torchmetrics import Accuracy
from torchmetrics import ConfusionMatrix
import pytorch_lightning as pl
import torch
from torch import nn
from omegaconf import OmegaConf
from hydra.utils import instantiate
import numpy as np
import wandb
from pathlib import Path
import os
import logging
import wandb
import copy


# Module-level logger used by the class below for progress/debug messages.
log = logging.getLogger(__name__)

# Classification LightningModule with optional FedProx proximal regularization.

class N_ClassificationModel(pl.LightningModule):
    """LightningModule for multi-class classification in a federated setting.

    The wrapped classifier is either the ``learner_model`` passed in or a fresh
    model built from ``cfg.model`` with Hydra's ``instantiate``. It is trained
    with cross-entropy, and metrics are logged under per-client / per-round keys.

    When ``cfg.Fedprox`` is truthy, a frozen copy of the starting model is kept
    and the FedProx proximal term ``(mu / 2) * sum ||w - w_start||^2`` is added
    to the training loss.

    Args:
        cfg: Config object. Fields read by this class:

            - ``num_classes``, ``learner_client``, ``round``, ``currentE``,
              ``track_round`` and ``Fedprox``.
            - ``model``, when no ``learner_model`` is given.
            - ``mu``, only when ``Fedprox`` is set.
            - ``optim.optim``, in :meth:`configure_optimizers`.
        learner_model: Optional pre-built model to start training from.
    """

    def __init__(self, cfg, learner_model: nn.Module = None):
        super().__init__()
        self.save_hyperparameters(cfg)

        self.learner_model = learner_model
        self.num_classes = cfg.num_classes
        log.info(f"self.num_classes = {self.num_classes}")  # TODO: remove
        # Targets of the most recent validation batch (debug logging only).
        self.val_targets = []

        # `is not None` instead of truthiness: container modules such as
        # nn.Sequential define __len__, so an empty one would be falsy.
        if learner_model is not None:
            self.model: nn.Module = learner_model
            self.model.requires_grad_(requires_grad=True)
            self.model.train()
            log.info("using the provided model as a starting point")
        else:
            log.info(f"self.hparams.model: {self.hparams.model}")
            self.model = instantiate(self.hparams.model)
            log.info("initializing a new model..")

        log.info(f"model type: {type(self.model)}")

        # Index of the client being trained; advanced by next_client().
        self.current_client_idx = cfg.learner_client

        self.loss = nn.CrossEntropyLoss()

        # Separate metric instances per stage so epoch-level reductions don't mix.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        # Per-class accuracies, overwritten by the *_epoch_end hooks.
        self.per_class_test_acc = []
        self.per_class_val_acc = []

        # TODO: num_classes could be inferred from the data instead of the config.
        self.train_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)
        self.val_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)
        self.test_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)

        # Best validation accuracy seen so far.
        self.val_acc_best = MaxMetric()
        self.round = cfg.round
        # NOTE(review): currentE / track_round are stored but not used in this
        # class; a "halfway model" snapshot was sketched here and never finished.
        self.currentE = cfg.currentE
        self.track_round = cfg.track_round
        self.Fedprox = cfg.Fedprox
        if self.Fedprox:
            # Snapshot of the starting weights, used as the proximal anchor.
            # It is frozen. Otherwise backward() would compute gradients for its
            # parameters, and nothing would ever clear them, since only
            # self.model's parameters are handed to the optimizer.
            self.teacher_model = copy.deepcopy(self.model)
            self.teacher_model.requires_grad_(False)
            self.mu = cfg.mu
            log.info(f">> Doing FedProx, with mu = {cfg.mu}")

    def next_client(self):
        """Advance the client index used in logged metric names."""
        self.current_client_idx += 1

    def forward(self, x):
        """Return the wrapped model's output (logits) for input ``x``."""
        return self.model(x)

    def step(self, batch):
        """Run the model on one ``(x, y)`` batch.

        Returns:
            Tuple ``(loss, preds, y)``:

            - ``loss``: the cross-entropy loss.
            - ``preds``: the argmax class predictions over dim 1.
            - ``y``: the targets, squeezed to 1-D if they arrived with a
              trailing singleton dimension.
        """
        x, y = batch
        if y.ndim > 1:
            # Some datasets (e.g. medMNIST) yield labels shaped (N, 1);
            # CrossEntropyLoss expects class indices shaped (N,).
            y = y.squeeze(1)

        logits = self(x)
        loss = self.loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        """Compute the training loss (plus the FedProx term if enabled) and log it."""
        cls_loss, preds, targets = self.step(batch)
        acc = self.train_acc(preds, targets)
        prefix = f"train_client-{self.current_client_idx}-round{self.round}"

        loss = cls_loss
        if self.Fedprox:
            # FedProx proximal term: (mu / 2) * squared L2 distance to the anchor.
            prox_loss = (self.mu / 2) * self.difference_models_norm_2(self.model, self.teacher_model)
            loss = loss + prox_loss
            self.log(f"{prefix}/cls_loss", cls_loss, on_step=False, on_epoch=True, prog_bar=False)
            self.log(f"{prefix}/prox_loss", prox_loss, on_step=False, on_epoch=True, prog_bar=False)

        self.log(f"{prefix}/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log(f"{prefix}/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        # The "loss" key is required so Lightning can run backpropagation.
        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_step(self, batch, batch_idx):
        """Update validation metrics for one batch and log loss/accuracy."""
        loss, preds, targets = self.step(batch)
        self.val_confusion_matrix(preds, targets)
        # Kept to inspect fair vs. unfair validation splits. TODO: remove this.
        self.val_targets = targets

        acc = self.val_acc(preds, targets)
        log.info(f"val_acc: {acc}")
        self.log(f"val_client-{self.current_client_idx}/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log(f"val_client-{self.current_client_idx}/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs):
        """Aggregate validation metrics, log them and the confusion matrix, then reset."""
        self.acc = self.val_acc.compute()  # accuracy over the whole validation epoch
        self.val_acc_best.update(self.acc)
        self.per_class_val_acc = self._per_class_accuracy(self.val_confusion_matrix.compute())

        self.log(f"val_client-{self.current_client_idx}/acc_best", self.val_acc_best.compute(), on_epoch=True,
                 prog_bar=True)

        log.info(f"per_class_val_acc = {self.per_class_val_acc}")
        log.info(f">> val_acc = {self.acc}")
        log.info(f"val targets = {self.val_targets}")  # last validation batch, TODO: remove this

        self._log_confusion_matrix(
            f"val_client-{self.current_client_idx}-round{self.round}/confusion_matrix", outputs
        )

        # Uses current_client_idx, like every other key in this class. The
        # original used hparams.learner_client, which stops matching once
        # next_client() has been called.
        self.logger.experiment.summary[
            f"client-{self.current_client_idx}/val_per_class_acc"
        ] = self.per_class_val_acc

        self.val_acc.reset()
        self.val_confusion_matrix.reset()

    def test_step(self, batch, batch_idx):
        """Update test metrics for one batch and log loss/accuracy."""
        loss, preds, targets = self.step(batch)
        self.test_confusion_matrix(preds, targets)
        acc = self.test_acc(preds, targets)

        self.log(f"test_client-{self.current_client_idx}-round{self.round}/loss", loss, on_step=False, on_epoch=True)
        self.log(f"test_client-{self.current_client_idx}-round{self.round}/acc", acc, on_step=False, on_epoch=True)
        self.log(f"client-{self.current_client_idx}_best-test-acc-round{self.round}", acc, on_step=False, on_epoch=True)
        return {"loss": loss, "preds": preds, "targets": targets}

    def test_epoch_end(self, outputs):
        """Aggregate test metrics, log them and the confusion matrix, then reset."""
        acc = self.test_acc.compute()
        self.per_class_test_acc = self._per_class_accuracy(self.test_confusion_matrix.compute())

        self._log_confusion_matrix(
            f"test_client-{self.current_client_idx}-round{self.round}/confusion_matrix", outputs
        )

        summary_key = f"client-{self.current_client_idx}-round{self.round}/test_per_class_test_acc"
        self.logger.experiment.summary[summary_key] = self.per_class_test_acc
        log.info(f"{summary_key}: {self.per_class_test_acc}")

        self.loggers[-1].log_hyperparams({"test_acc": acc.cpu().item()})
        self.test_acc.reset()
        self.test_confusion_matrix.reset()

    def on_epoch_end(self):
        """Reset the accuracy metrics at the end of every epoch."""
        self.test_acc.reset()
        self.val_acc.reset()
        self.train_acc.reset()

    @staticmethod
    def _per_class_accuracy(confusion_matrix):
        """Return per-class accuracy as a numpy array.

        Computes ``diag(cm) / cm.sum(dim=1)``, i.e. correct counts divided by
        row totals. Torchmetrics puts targets on the rows, so this is per-class
        recall. A class with no samples yields NaN (0 / 0).
        """
        counts = confusion_matrix.detach().cpu()
        return (torch.diagonal(counts) / counts.sum(dim=1)).numpy()

    def _log_confusion_matrix(self, key, outputs):
        """Log a wandb confusion-matrix plot built from step ``outputs`` under ``key``."""
        y_true = np.concatenate([out["targets"].cpu().numpy() for out in outputs]).ravel()
        y_pred = np.concatenate([out["preds"].cpu().numpy() for out in outputs]).ravel()
        self.logger.experiment.log(
            {
                key: wandb.plot.confusion_matrix(
                    probs=None,
                    y_true=y_true,
                    preds=y_pred,
                    class_names=None,
                )
            },
            # Attach to the current wandb step instead of committing a new one.
            commit=False,
        )

    def difference_models_norm_2(self, model_1, model_2):
        """Return the squared L2 distance between two models' parameters.

        Parameters are paired positionally via ``parameters()``, so the two
        models are expected to share an architecture. Despite the name, no
        square root is taken: the result is ``sum_i sum((p1_i - p2_i) ** 2)``.

        Adapted from
        https://epione.gitlabpages.inria.fr/flhd/federated_learning/FedAvg_FedProx_MNIST_iid_and_noniid.html
        """
        return sum(
            torch.sum((p1 - p2) ** 2)
            for p1, p2 in zip(model_1.parameters(), model_2.parameters())
        )

    def configure_optimizers(self):
        """Build the optimizer from ``hparams.optim.optim`` over ``self.model``'s parameters only."""
        optim = instantiate(config=self.hparams.optim.optim, params=self.model.parameters())
        return optim